"""
包含2个部分：VAE解耦，RL决策。
VAE发现样本中的关联数据。
RL对VAE学到的关联进行直接学习，并作出决策，选择VAE学到的样本。
"""
import argparse
from collections import defaultdict

from torch import optim
from torch.utils.data import DataLoader
import wandb
from tqdm import trange

from disvae.models.anneal import *
from disvae.models.losses import BtcvaeLoss, get_loss_s
from exps.models import *
from exps.visualize import *
from exps.patterns import data_generator
import numpy as np
from utils.helpers import set_seed
import torch.distributions as D

# Default hyper-parameters; each key is auto-exposed as a CLI flag below.
hyperparameter_defaults = dict(
    batch_size=256,
    learning_rate=5e-4,
    epochs=2041,
    beta=50,  # KL weight passed to the 'betaH' loss
    dimension=6,  # VAE latent dimensionality
    img_id=0,  # which pattern file (patterns/<img_id>.pat) to load
    random_seed=210,
    file=__file__.split('/')[-1]  # script name, recorded in the wandb config
)

# Build one argparse flag per default; the flag's type is inferred from the
# default value (NOTE: this would misparse bools, but none are used here).
parser = argparse.ArgumentParser()
for key, value in hyperparameter_defaults.items():
    parser.add_argument(f'--{key}', default=value, type=type(value))
args = parser.parse_args()
config = args
wandb.init(project="exps", config=args, )
set_seed(config.random_seed)


def cal_log_prob(outputs):
    """Sample a label from each factor's logits and return the samples plus
    their summed negative log-likelihoods (REINFORCE score terms).

    Parameters
    ----------
    outputs : sequence of Tensor
        One ``(batch, n_classes_i)`` logit tensor per latent factor.

    Returns
    -------
    labels : LongTensor, shape (batch, n_factors)
        The sampled class index for every factor.
    log_probs : Tensor, shape (batch, 1)
        Sum over factors of ``-log p(label)``.  Note the positive sign:
        multiplying by the (negative) reward and calling backward in
        ``train`` performs gradient ascent on the log-likelihood.
    """
    labels = []
    log_probs = []
    for logits in outputs:
        # detach() (not the deprecated .data) so sampling itself is outside
        # the autograd graph; gradients flow only through the cross-entropy
        # term below.  Categorical(logits=...) normalizes internally, so the
        # explicit softmax of the original is unnecessary.
        dist = D.Categorical(logits=logits.detach())
        preds = dist.sample()
        # cross_entropy(logits, preds) == -log p(preds) per sample.
        nll = F.cross_entropy(logits, preds, reduction='none')
        labels.append(preds)
        log_probs.append(nll)
    return torch.stack(labels, 1), torch.stack(log_probs, 1).sum(1, True)


def train(dl, model, loss_f, epochs):
    """Jointly train the VAE ``model`` and the module-global ``picker``.

    For every batch, the picker (a REINFORCE policy over the VAE's latent
    sample) proposes target samples 10 times; the reward is the negative
    reconstruction distance between the proposed sample and the input.  The
    VAE loss is then computed against the last proposed target, and a single
    optimizer step updates both networks.

    Parameters
    ----------
    dl : DataLoader yielding (img, label) batches.
    model : the VAE (must return recon, latent_dist, latent_sample).
    loss_f : VAE loss callable, e.g. from ``get_loss_s``.
    epochs : number of training epochs.

    Returns
    -------
    storer : dict of the last epoch's averaged metrics / wandb media.
    """
    opt = optim.AdamW(model.parameters(), config.learning_rate)
    opt.add_param_group({'params': picker.parameters()})
    for e in trange(epochs):
        storer = defaultdict(list)
        for img, label in dl:
            bs = len(img)
            img = img.view(-1, 1, 64, 64).cuda()
            recon_batch, latent_dist, latent_sample = model(img)

            # BUG FIX: gradients must be zeroed *before* the policy backward
            # passes.  The original called opt.zero_grad() after this loop,
            # which wiped the accumulated REINFORCE gradients, so the picker
            # was never actually trained.
            opt.zero_grad()
            for _ in range(10):
                outputs = picker(latent_sample)
                labels, log_probs = cal_log_prob(outputs)  # pick samples based on z
                target_imgs = dataset.o_imgs[labels[:, 0], labels[:, 1]].reshape(bs, 1, 64, 64).cuda()
                # Reward: negative per-sample BCE between picked target and input.
                r = - F.binary_cross_entropy(target_imgs, img, reduction='none').sum([2, 3])
                # REINFORCE loss: log_probs is -log p, so minimizing
                # (-log p) * r performs ascent on E[r].
                policy = (log_probs * r).mean()
                # retain_graph: latent_sample's graph is reused by all 10
                # rollouts and by the VAE loss backward below.
                policy.backward(retain_graph=True)

            # VAE loss reconstructs the *picked* target, not the input.
            loss = loss_f(target_imgs, recon_batch, latent_dist, True, storer=storer)
            loss.backward()
            opt.step()

            storer['r'].append(r.mean().item())

        # Collapse per-batch metric lists into epoch means.
        for k, v in storer.items():
            if isinstance(v, list):
                storer[k] = np.mean(v)

        if (e + 1) % 40 == 0:
            # Periodically log sample images and a scatter of the three
            # latent dimensions with the largest KL (most informative).
            storer['recon'] = wandb.Image(recon_batch[0, 0])
            storer['img'] = wandb.Image(img[0, 0])

            kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
            _, index = kl_loss.sort(descending=True)

            preds, targets = get_preds(model, dl)
            fig = gt_vs_latent(preds[:, index[:3]].cpu(), targets.cpu())
            storer['correlated'] = wandb.Image(fig)

        storer['epoch'] = e
        wandb.log(storer, sync=False)
        plt.close()
    return storer


# ---- data ------------------------------------------------------------------
img_id = config.img_id
# Load one base pattern image; the Translation dataset generates its shifted
# variants (dataset.o_imgs is indexed by the two translation factors in train).
img = torch.load(f'patterns/{img_id}.pat')
# img = data_generator.rotate(img,45)
dataset = data_generator.Translation(img)

epochs = config.epochs
dim = config.dimension  # VAE latent dimensionality

dl = DataLoader(dataset, pin_memory=True,
                num_workers=0, batch_size=config.batch_size, shuffle=True)


class Picker(nn.Module):
    """MLP policy head used by the RL component.

    Takes the current latent encoding as input and produces one logit
    vector per ground-truth factor; the target sample's label is drawn
    from these per-factor categorical distributions.
    """

    def __init__(self, label_size, latent_dim):
        super().__init__()
        self.label_size = label_size
        self.latent_dim = latent_dim
        self.hidden = 128
        # Single output layer spans all factors; forward() splits it back up.
        total_logits = np.sum(self.label_size)
        layers = [nn.Linear(latent_dim, self.hidden), nn.ReLU()]
        for _ in range(2):
            layers.extend([nn.Linear(self.hidden, self.hidden), nn.ReLU()])
        layers.append(nn.Linear(self.hidden, total_logits))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        # One (batch, n_classes_i) tensor per factor.
        return logits.split(self.label_size, 1)


# ---- model setup, training, and final visualization ------------------------
vae = VAE((1, 64, 64), dim)
# One categorical head per ground-truth latent factor of the dataset.
picker = Picker(dataset.lat_sizes.tolist(), dim).cuda()
vae.cuda()
vae.train()
# Monotonic annealing schedule over the full number of optimizer steps.
anneal = get_anneal('monotonic', len(dl) * epochs, 1, 0.2)
loss_f = get_loss_s('betaH', anneal, config.beta, len(dl.dataset))

storer = train(dl, vae, loss_f, epochs)

# Rank latent dimensions by their KL term (largest = most informative).
kl_loss = torch.Tensor([v for k, v in storer.items() if 'kl_' in k])
_, index = kl_loss.sort(descending=True)
print(index)
vae.eval()
vae.cpu()

# 3-D point cloud of the three most informative latent dimensions.
preds, targets = get_preds(vae, dl)
points = preds[:, index[:3]].cpu()
storer['point_cloud'] = wandb.Object3D(points.numpy())

# Latent traversal figure across the KL-ranked dimensions.
fig = plt_sample_traversal(None, vae, 7, index, r=2)
storer['traversal'] = wandb.Image(fig)

wandb.log(storer)
plt.show()

torch.save(vae.state_dict(), 'tmp.pt')
